# -*- coding:utf-8 -*-

import tensorflow as tf
from functools import reduce
from config.glob.global_pool import global_pool

"""
Mnist数据集
每一个自定义的数据集解析器必须有load_ori_dataset和pre_process函数, 必须有形参config
"""


def resize_x_y(x, y, x_shape, y_shape):
    """
    Reshape one dataset element (input ``x`` and label ``y``).

    The element count implied by ``x_shape`` must be exactly 784
    (28 * 28, the flattened MNIST image size); ``y_shape`` is not validated
    here and is handed straight to ``y.reshape``.

    :param x: input sample; must expose ``reshape`` (e.g. a numpy array).
    :param y: label; must expose ``reshape`` (e.g. a numpy array).
    :param x_shape: iterable of dimension sizes for ``x``.
    :param y_shape: target shape for ``y``.
    :return: tuple ``(x.reshape(x_shape), y.reshape(y_shape))``.
    :raises ValueError: if the product of ``x_shape`` is not 784. An empty
        ``x_shape`` is reported through this same error.
    """
    expected_size = 784

    # The initial value 1 keeps reduce() from raising an unrelated TypeError
    # on an empty shape; the product is then 1 and fails the check below.
    element_count = reduce(lambda a, b: a * b, x_shape, 1)
    if element_count != expected_size:
        # ValueError is a subclass of Exception, so callers that catch
        # Exception keep working.
        raise ValueError('input reshape不匹配，shape:{}乘积应为784'.format(x_shape))

    return x.reshape(x_shape), y.reshape(y_shape)


def handle(dataset_in):
    """
    Map the reshape step over every ``(x, y)`` element of a dataset.

    Each element is routed through ``resize_x_y`` via ``tf.py_func``, using
    the target shapes stored on ``global_pool.config`` (``xs_shape`` and
    ``ys_shape``).

    :param dataset_in: object exposing ``map`` whose mapped function receives
        two arguments ``(x, y)`` -- presumably a ``tf.data.Dataset``; confirm
        against the caller.
    :return: whatever ``dataset_in.map`` returns for the reshape function.
    """

    def _reshape_element(x, y):
        # The config is read when the mapped function runs, not when
        # handle() is called.
        target_x_shape = global_pool.config.xs_shape
        target_y_shape = global_pool.config.ys_shape
        py_inputs = [x, y, target_x_shape, target_y_shape]
        # Both py_func outputs are declared as float32 tensors.
        return tf.py_func(
            resize_x_y,
            inp=py_inputs,
            Tout=[tf.float32, tf.float32],
        )

    return dataset_in.map(_reshape_element)
